{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 2: Intro to Federated Learning\n",
    "\n",
    "In the last section, we learned about PointerTensors, which provide the underlying infrastructure we need for privacy-preserving Deep Learning. In this section, we're going to see how to use these basic tools to implement our first privacy-preserving deep learning algorithm: Federated Learning.\n",
    "\n",
    "Authors:\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    "\n",
    "### What is Federated Learning?\n",
    "\n",
    "It's a simple, powerful way to train Deep Learning models. If you think about training data, it's always the result of some sort of collection process. People (via devices) generate data by recording events in the real world. Normally, this data is aggregated in a single, central location so that you can train a machine learning model. Federated Learning turns this on its head!\n",
    "\n",
    "Instead of bringing training data to the model (a central server), you bring the model to the training data (wherever it may live).\n",
    "\n",
    "The idea is that this allows whoever creates the data to own the only permanent copy, and thus to keep control over who has access to it. Pretty cool, eh?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Section 2.1 - A Toy Federated Learning Example\n",
    "\n",
    "Let's start by training a toy model the centralized way. This is about as simple as models get. We first need:\n",
    "\n",
    "- a toy dataset\n",
    "- a model\n",
    "- some basic training logic to fit the model to the data.\n",
    "\n",
    "Note: If this API is unfamiliar to you, head on over to [fast.ai](http://fast.ai) and take their course before continuing with this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A Toy Dataset\n",
    "data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)\n",
    "target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)\n",
    "\n",
    "# A Toy Model\n",
    "model = nn.Linear(2,1)\n",
    "\n",
    "def train():\n",
    "    # Training Logic\n",
    "    opt = optim.SGD(params=model.parameters(),lr=0.1)\n",
    "    # named `epoch` so we don't shadow Python's built-in iter()\n",
    "    for epoch in range(20):\n",
    "\n",
    "        # 1) erase previous gradients (if they exist)\n",
    "        opt.zero_grad()\n",
    "\n",
    "        # 2) make a prediction\n",
    "        pred = model(data)\n",
    "\n",
    "        # 3) calculate how much we missed\n",
    "        loss = ((pred - target)**2).sum()\n",
    "\n",
    "        # 4) figure out which weights caused us to miss\n",
    "        loss.backward()\n",
    "\n",
    "        # 5) change those weights\n",
    "        opt.step()\n",
    "\n",
    "        # 6) print our progress\n",
    "        print(loss.data)"
   ]
  },
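  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what the training logic above computes (the symbols $w$, $b$, $x_i$, $y_i$ and $\\eta$ are notation introduced here for illustration, not names from the code): `nn.Linear(2,1)` predicts $w^\\top x_i + b$, the loss is the summed squared error over the four examples, and each `opt.step()` with plain SGD and `lr=0.1` moves the parameters against the gradient.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "L(w, b) &= \\sum_{i=1}^{4} \\left(w^\\top x_i + b - y_i\\right)^2 \\\\\\\\\n",
    "\\nabla_w L &= 2 \\sum_{i=1}^{4} \\left(w^\\top x_i + b - y_i\\right) x_i, \\qquad \\frac{\\partial L}{\\partial b} = 2 \\sum_{i=1}^{4} \\left(w^\\top x_i + b - y_i\\right) \\\\\\\\\n",
    "w &\\leftarrow w - \\eta \\, \\nabla_w L, \\qquad b \\leftarrow b - \\eta \\, \\frac{\\partial L}{\\partial b}, \\qquad \\eta = 0.1\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Steps 3 to 5 in the code correspond to these three lines in order: compute $L$, backpropagate to get the gradients, then apply the update."
   ]
  },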
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And there you have it! We've trained a basic model in the conventional manner. All our data is aggregated on our local machine and we can use it to make updates to our model. Federated Learning, however, doesn't work this way. So, let's modify this example to do it the Federated Learning way!\n",
    "\n",
    "So, what do we need to do?\n",
    "\n",
    "- create a couple of workers\n",
    "- get pointers to the training data on each worker\n",
    "- update the training logic to do federated learning\n",
    "\n",
    "New training steps:\n",
    "\n",
    "- send the model to the correct worker\n",
    "- train on the data located there\n",
    "- get the model back and repeat with the next worker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "hook = sy.TorchHook(torch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a couple workers\n",
    "\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A Toy Dataset\n",
    "data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)\n",
    "target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)\n",
    "\n",
    "# get pointers to training data on each worker by\n",
    "# sending some training data to bob and alice\n",
    "data_bob = data[0:2]\n",
    "target_bob = target[0:2]\n",
    "\n",
    "data_alice = data[2:]\n",
    "target_alice = target[2:]\n",
    "\n",
    "# Initialize A Toy Model\n",
    "model = nn.Linear(2,1)\n",
    "\n",
    "data_bob = data_bob.send(bob)\n",
    "data_alice = data_alice.send(alice)\n",
    "target_bob = target_bob.send(bob)\n",
    "target_alice = target_alice.send(alice)\n",
    "\n",
    "# organize pointers into a list\n",
    "# (the optimizer is created inside train(), so none is needed here)\n",
    "datasets = [(data_bob,target_bob),(data_alice,target_alice)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    # Training Logic\n",
    "    opt = optim.SGD(params=model.parameters(),lr=0.1)\n",
    "    # named `epoch` so we don't shadow Python's built-in iter()\n",
    "    for epoch in range(10):\n",
    "\n",
    "        # NEW) iterate through each worker's dataset\n",
    "        for data,target in datasets:\n",
    "\n",
    "            # NEW) send model to correct worker\n",
    "            model.send(data.location)\n",
    "\n",
    "            # 1) erase previous gradients (if they exist)\n",
    "            opt.zero_grad()\n",
    "\n",
    "            # 2) make a prediction\n",
    "            pred = model(data)\n",
    "\n",
    "            # 3) calculate how much we missed\n",
    "            loss = ((pred - target)**2).sum()\n",
    "\n",
    "            # 4) figure out which weights caused us to miss\n",
    "            loss.backward()\n",
    "\n",
    "            # 5) change those weights\n",
    "            opt.step()\n",
    "\n",
    "            # NEW) get the updated model back (with its gradients)\n",
    "            model.get()\n",
    "\n",
    "            # 6) print our progress\n",
    "            # NEW) the loss lives on the worker, so we call .get() to retrieve it\n",
    "            print(loss.get())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Well Done!\n",
    "\n",
    "And voilà! We are now training a very simple Deep Learning model using Federated Learning! We send the model to each worker, train it on the data that lives there, and then bring the updated model back to our local machine. Never in this process do we see or request access to the underlying training data! We preserve the privacy of Bob and Alice!!!\n",
    "\n",
    "## Shortcomings of this Example\n",
    "\n",
    "So, while this example is a nice introduction to Federated Learning, it still has some major shortcomings. Most notably, when we call `model.get()` and receive the updated model from Bob or Alice, we can actually learn a lot about Bob and Alice's training data by looking at their gradients. In some cases, we can restore their training data perfectly!\n",
    "\n",
    "So, what is there to do? Well, the first strategy people employ is to **average the gradient across multiple individuals before uploading it to the central server**. This strategy, however, will require some more sophisticated use of PointerTensor objects. So, in the next section, we're going to take some time to learn about more advanced pointer functionality, and then we'll upgrade this Federated Learning example.\n"
   ]
  },
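  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the gradient leakage mentioned above concrete, here is a small worked case under simplifying assumptions: the linear model and squared-error loss used in this notebook, and a hypothetical worker that holds a single example $(x, y)$. The symbols $w$, $b$ and $r$ are notation introduced here for illustration.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "r &= w^\\top x + b - y, \\qquad L = r^2 \\\\\\\\\n",
    "\\nabla_w L &= 2\\,r\\,x, \\qquad \\frac{\\partial L}{\\partial b} = 2\\,r \\\\\\\\\n",
    "x &= \\frac{\\nabla_w L}{\\partial L / \\partial b} \\quad (r \\neq 0), \\qquad y = w^\\top x + b - \\tfrac{1}{2}\\,\\frac{\\partial L}{\\partial b}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "In this single-example case, anyone who sees the gradients and knows the weights that were sent out can recover both the input and the target exactly. When a worker holds several examples, as Bob and Alice do here, the weight gradient is the sum $2\\sum_i r_i x_i$, so one gradient generally reveals a residual-weighted combination of the inputs rather than each input on its own."
   ]
  },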
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top-level tickets, giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
